Skip to content

Commit e9ff20e

Browse files
committed
support for image attachments when classifying questions
1 parent ab1cdd3 commit e9ff20e

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

kitsune/llm/questions/classifiers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def classify_question(question: "Question") -> dict[str, Any]:
3232
payload: dict[str, Any] = {
3333
"subject": question.title,
3434
"question": question.content,
35+
"image_urls": [image.get_absolute_url() for image in question.get_images()],
3536
"product": product,
3637
"topics": get_taxonomy(
3738
product, include_metadata=["description", "examples"], output_format="JSON"

kitsune/llm/questions/prompt.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
2-
from langchain.prompts import ChatPromptTemplate
2+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
3+
from langchain.schema import HumanMessage
4+
from langchain.schema.runnable import RunnableLambda
35

46
SPAM_INSTRUCTIONS = """
57
# Role and goal
@@ -120,17 +122,49 @@
120122
)
121123

122124

123-
spam_prompt = ChatPromptTemplate(
125+
spam_prompt_template = ChatPromptTemplate(
124126
(
125127
("system", SPAM_INSTRUCTIONS),
126-
("human", USER_QUESTION),
128+
MessagesPlaceholder("human_message"),
127129
)
128130
).partial(format_instructions=spam_parser.get_format_instructions())
129131

130132

131-
topic_prompt = ChatPromptTemplate(
133+
topic_prompt_template = ChatPromptTemplate(
132134
(
133135
("system", TOPIC_INSTRUCTIONS),
134-
("human", USER_QUESTION),
136+
("human_message"),
135137
)
136138
).partial(format_instructions=topic_parser.get_format_instructions())
139+
140+
141+
def create_human_message(inputs: dict) -> dict:
142+
"""
143+
Creates the human message, with the image URL's if they're present, and
144+
then adds it to the inputs dict. Returns the modified inputs dict.
145+
"""
146+
content: list[dict] = [
147+
{
148+
"type": "text",
149+
"text": USER_QUESTION.format(**inputs),
150+
},
151+
]
152+
153+
for image_url in inputs.get("image_urls", ()):
154+
content.append(
155+
{
156+
"type": "image_url",
157+
"image_url": {
158+
"url": image_url,
159+
},
160+
}
161+
)
162+
163+
inputs["human_message"] = [HumanMessage(content=content)]
164+
return inputs
165+
166+
167+
spam_prompt = RunnableLambda(create_human_message) | spam_prompt_template
168+
169+
170+
topic_prompt = RunnableLambda(create_human_message) | topic_prompt_template

0 commit comments

Comments
 (0)